【深度学习-RNN系列】GRU简介 & 源码实现

简介

LSTM 计算较为复杂,参数也非常多,难以训练。GRU(Gated Recurrent Units)应运而生。

在 GRU 中,大幅简化了 LSTM 结构

背后的逻辑

  1. 新增 reset gate,即图中的 r 开关;
  2. 将输入门和遗忘门合并为“update gate”,即图中的 z 开关;
  3. 将细胞状态 C 和隐藏状态 m 合并为 h;
  4. 省掉了输出门;

应用实例

  • 百度的Deep speech2

GRU的实现源码

GRU-tensorflow

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

def call(self, inputs, state):
"""Gated recurrent unit (GRU) with nunits cells."""

gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._gate_kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

value = math_ops.sigmoid(gate_inputs)
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

r_state = r * state

candidate = math_ops.matmul(
array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
candidate = nn_ops.bias_add(candidate, self._candidate_bias)

c = self._activation(candidate)
new_h = u * state + (1 - u) * c
return new_h, new_h

pytorch

缺陷

参考